import os
import time
import torch as t
import torch.nn as nn
import torchvision as tv
import torchvision.transforms as tr
from torch.utils.data import DataLoader
import numpy as np


def norm_ip(img, min, max):
    """Clamp ``img`` to ``[min, max]`` and rescale it linearly towards ``[0, 1]``.

    Args:
        img: Tensor to normalise; it is not modified in place.
        min: Lower clamp bound, mapped to 0.
        max: Upper clamp bound, mapped to just under 1.

    Returns:
        A new tensor holding ``(clamp(img) - min) / (max - min + 1e-5)``.
    """
    # The 1e-5 keeps the denominator non-zero when max == min.
    span = max - min + 1e-5
    clamped = t.clamp(img, min=min, max=max)
    return (clamped - min) / span


def cond_is_fid(f, new_buffer, test_loader, args, device='cuda', ratio=0.1, eval='all', batch_size=100):
    """Score the most confidently classified samples of each predicted class.

    Every sample in ``new_buffer`` is classified with ``f``. For each label in
    ``range(args.num_labels)`` the samples predicted as that label are ranked
    by their maximum softmax probability, the top ones are kept, and the
    concatenated selection is passed to ``eval_is_fid``.

    Args:
        f: Callable mapping a batch of images to a 2-D tensor of logits
            (one row per sample).
        new_buffer: Tensor of samples, indexed along dimension 0.
        test_loader: Forwarded unchanged to ``eval_is_fid``.
        args: Object exposing ``num_labels``.
        device: Device the buffer is moved to before classification.
        ratio: If ``< 1``, the fraction of each class's samples to keep;
            otherwise an absolute per-class count (so exactly ``1`` keeps a
            single sample per class). The count is capped at the class size.
        eval: Forwarded to ``eval_is_fid``.
        batch_size: Number of samples classified per forward pass.

    Returns:
        The ``(inc_score, std, fid)`` triple produced by ``eval_is_fid``.
    """
    new_buffer = new_buffer.to(device)
    all_y, probs = [], []
    with t.no_grad():
        # Step by start offset so the trailing partial batch is classified too.
        # The predictions must line up one-to-one with ``new_buffer`` for the
        # boolean masks below to be valid.
        for begin in range(0, new_buffer.size(0), batch_size):
            logits = f(new_buffer[begin:begin + batch_size])
            all_y.append(logits.max(1)[1])
            probs.append(nn.Softmax(dim=1)(logits).max(1)[0])
    all_y = t.cat(all_y, 0)
    probs = t.cat(probs, 0)

    masks = [all_y == label for label in range(args.num_labels)]
    each_class = [new_buffer[mask] for mask in masks]
    each_class_probs = [probs[mask] for mask in masks]
    print([len(c) for c in each_class])

    selected = []
    for class_images, class_probs in zip(each_class, each_class_probs):
        if ratio < 1:
            keep = int(len(class_probs) * ratio)
        else:
            keep = int(ratio)
        keep = min(keep, len(class_probs))
        # topk returns (values, indices); only the indices are needed here.
        top_indices = t.topk(class_probs, keep)[1]
        selected.append(class_images[top_indices])

    new_buffer = t.cat(selected, 0)
    print(new_buffer.shape)
    inc_score, std, fid = eval_is_fid(new_buffer, test_loader, eval=eval)
    return inc_score, std, fid


def eval_is_fid(replay_buffer, test_loader, eval='all'):
    """Compute the Inception Score and/or FID of generated images.

    Args:
        replay_buffer: The generated images, or a list/tuple whose first
            element holds them. Iterating the images must yield 3-D tensors
            that are transposed from axis order (0, 1, 2) to (1, 2, 0).
        test_loader: Iterable yielding ``(data, _)`` pairs, where ``data`` is
            a 4-D tensor exposing ``.numpy()``; it supplies the reference
            images for FID.
        eval: ``'is'``, ``'fid'`` or ``'all'``, selecting which metrics are
            computed.

    Returns:
        ``(score, std, fid)``. ``score`` and ``std`` stay ``0`` when the
        Inception Score is not requested; ``fid`` stays ``-1`` when FID is not
        requested.
    """
    # A list or tuple buffer carries the images in its first slot.
    if isinstance(replay_buffer, (list, tuple)):
        images = replay_buffer[0]
    else:
        images = replay_buffer

    # Each image is clamped/rescaled by norm_ip, moved to channel-last layout
    # and multiplied by 255.
    feed_imgs = np.stack([
        norm_ip(img, -1, 1).cpu().numpy().transpose(1, 2, 0) * 255
        for img in images
    ])

    def rescale_im(im):
        return np.clip(im * 256, 0, 255).astype(np.uint8)

    test_ims = []
    for data, _ in test_loader:
        # NOTE(review): the batch is multiplied by 255 here and by 256 again
        # inside rescale_im before clipping to [0, 255], which looks like
        # double scaling. Kept as-is because the value range produced by
        # test_loader is not visible here - confirm against the caller.
        data = data.numpy().transpose(0, 2, 3, 1) * 255
        test_ims.extend(rescale_im(data))

    print(feed_imgs.shape, len(test_ims), test_ims[0].shape)

    from evaluate.gen.fid import get_fid_score
    from evaluate.gen.inception_tf import get_inception_score

    fid = -1
    if eval in ['fid', 'all']:
        fid = get_fid_score(feed_imgs, test_ims)

    score, std = 0, 0
    if eval in ['is', 'all']:
        # One split per 5000 generated images, with at least one split.
        splits = max(1, len(feed_imgs) // 5000)
        score, std = get_inception_score(feed_imgs, splits=splits)
    return score, std, fid